# -*- coding: utf-8 -*-
"""
Created on Tue Apr  5 11:26:05 2022

"""
import pickle
#import networkx
import numpy as np
import seaborn as sns
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler,MinMaxScaler
import torch
import torch.utils.data as utils
import random
import pandas as pd


def normalize_rows(mx):
    """Scale each row of a dense 2-D array so that the row sums to one.

    Rows whose sum is zero come back as all zeros rather than inf/NaN.

    Args:
        mx: 2-D numpy array. Integer (and boolean) input is promoted to
            float, because a negative power of an integer array raises in
            numpy.

    Returns:
        A new floating-point array with the shape of ``mx``; the input is
        not modified.
    """
    mx = np.asarray(mx)
    if not np.issubdtype(mx.dtype, np.inexact):
        mx = mx.astype(float)

    row_sums = mx.sum(axis=1)
    # 1/0 is expected for empty rows; silence the warning and zero it below.
    with np.errstate(divide='ignore'):
        r_inv = np.power(row_sums, -1.0)
    r_inv[np.isinf(r_inv)] = 0.

    # Broadcasting scales row i by r_inv[i]; same result as diag(r_inv) @ mx
    # without materialising an n-by-n matrix.
    return mx * r_inv[:, np.newaxis]


#%%

def _strict_cell_index(values, edges):
    """Return (cell, inside) where edges[cell] < value < edges[cell+1] holds strictly."""
    idx = np.searchsorted(edges, values, side='left')
    # side='left' gives edges[idx-1] < value <= edges[idx]; drop values that
    # sit exactly on an edge or fall outside the first/last edge.
    inside = (idx >= 1) & (idx < edges.shape[0])
    safe = np.minimum(idx, edges.shape[0] - 1)
    inside &= edges[safe] != values
    return idx - 1, inside


def create_edgelist(node_coord,data_path,res=30,extent=330):
    """Assign every node to a cell of a regular square grid.

    Columns 1 and 2 of ``node_coord`` are shifted so that their minimum
    becomes 0.1, then binned into cells of size ``res`` covering
    ``[0, extent)`` on both axes. A node is kept only if it lies strictly
    inside a cell; nodes exactly on a cell boundary or outside the grid are
    left out of the result.

    Args:
        node_coord: 2-D array whose column 0 is the node id and whose
            columns 1 and 2 are the two coordinates. It is NOT modified
            (the shift is applied to a floating-point copy).
        data_path: prefix the file name ``edge_list.pkl`` is appended to.
        res: cell size.
        extent: upper bound of the grid axes (defaults to the previously
            hard-coded 330).

    Returns:
        (edge_list, count): ``edge_list`` has one row ``[node_id, i, j]``
        per kept node, ordered by node id (shape ``(0, 3)`` when no node is
        kept); ``count[i, j]`` is the number of nodes in cell ``(i, j)``.

    Side effects:
        Pickles ``edge_list`` to ``data_path + 'edge_list.pkl'``.
    """
    coords = np.array(node_coord, copy=True)
    if not np.issubdtype(coords.dtype, np.floating):
        # An integer array would silently truncate the +0.1 offset.
        coords = coords.astype(float)
    coords[:, 1] = coords[:, 1] - coords[:, 1].min() + 0.1
    coords[:, 2] = coords[:, 2] - coords[:, 2].min() + 0.1

    x = np.arange(0, extent, res)
    y = np.arange(0, extent, res)
    count = np.zeros((x.shape[0] - 1, x.shape[0] - 1))

    rows, in_x = _strict_cell_index(coords[:, 1], x)
    cols, in_y = _strict_cell_index(coords[:, 2], y)
    keep = np.flatnonzero(in_x & in_y)
    # Visit kept nodes in (i, j, node position) order, i.e. the order the
    # nested loops over cells and nodes would have produced.
    keep = keep[np.lexsort((keep, cols[keep], rows[keep]))]

    np.add.at(count, (rows[keep], cols[keep]), 1)
    edge_list = np.column_stack((coords[keep, 0], rows[keep], cols[keep]))

    node_order = np.argsort(edge_list[:, 0])
    edge_list = edge_list[node_order]

    with open(data_path+'edge_list.pkl','wb') as f:
        pickle.dump(edge_list,f)

    return edge_list,count
    

#B_T_N+B_T->B_N_T_F
def org_seri(temp_time,temp_input):
    """Pair the shared time axis with each node's series.

    ``temp_time`` (B_T) is stacked with ``temp_input[:, :, n]`` (B_T) along
    a new last axis for every node ``n``; the per-node results are then
    stacked along axis 1, giving B_N_T_F with F = 2 (time, value).
    """
    node_count = temp_input.shape[2]
    per_node = [
        np.stack((temp_time, temp_input[:, :, node]), axis=2)  # T_F per batch
        for node in range(node_count)
    ]
    return np.stack(per_node, axis=1)


def g2cnn(array,edge_list,count,paddings):
    '''
    Scatter per-node features onto the 2-D grid described by ``count``.

    array shape is B_T_F_N; the result is B_T_F_H_W with
    H, W = count.shape[0], count.shape[1].

    paddings controls the value of grid cells that hold no node:
      * a scalar (anything that is not a list): every cell of every
        feature channel starts at that value;
      * a list with one entry per feature channel: the string 'same'
        builds the channel with same_padding, any non-string entry builds
        it with int_padding. Any other string raises ValueError (it used
        to be skipped silently, leaving a channel missing).

    Row i of edge_list is [node_id, grid_row, grid_col] and supplies the
    cell that array[..., i] is written to. When several rows point at the
    same cell, the later row overwrites the earlier one.
    '''
    height, width = count.shape[0], count.shape[1]

    if not isinstance(paddings, list):
        grid = paddings*np.ones((array.shape[0],array.shape[1],
                                 array.shape[2],height,width))
    else:
        channels = []
        for j, padding in enumerate(paddings):
            if isinstance(padding, str):
                if padding != 'same':
                    raise ValueError(
                        "unsupported padding %r for channel %d; "
                        "expected 'same' or a number" % (padding, j))
                channels.append(same_padding(array, count, j, padding))
            else:
                channels.append(int_padding(array, count, padding))
        grid = np.concatenate(channels, axis=2)

    # Explicit loop (not fancy-index assignment) so that duplicate cells
    # keep a well-defined last-writer-wins result.
    for i in range(edge_list.shape[0]):
        grid[:,:,:,int(edge_list[i,1]),int(edge_list[i,2])] = array[:,:,:,i]

    return grid


def same_padding(array,count,j,padding):
    '''
    Build one grid channel filled with feature ``j`` of node 0.

    array shape is B_T_F_N
    returned shape is B_T_1_H_W, with H, W = count.shape[0], count.shape[1]

    ``padding`` is accepted but not used by this function.
    '''
    grid_h, grid_w = count.shape[0], count.shape[1]
    first_node = array[:, :, j, 0]
    # Add the channel and the two spatial axes, then repeat over the grid.
    seed = np.expand_dims(first_node, (2, 3, 4))
    return np.tile(seed, (1, 1, 1, grid_h, grid_w))

def int_padding(array,count,padding):
    '''
    Build one grid channel filled with the constant ``padding``.

    array shape is B_T_F_N (only its first two dimensions are read)
    returned shape is B_T_1_H_W, with H, W = count.shape[0], count.shape[1]

    The width is taken from count.shape[1] so that the result lines up with
    same_padding on non-square grids as well.
    '''
    temp=padding*np.ones((array.shape[0],array.shape[1],
                          1,count.shape[0],count.shape[1]))
    return temp


def stand_heads(train_label, train_label_mask, val_label,val_label_mask):
    """Min-max scale labels with the range of the masked training values.

    The minimum and the range are computed only from the entries of
    ``train_label`` selected by ``train_label_mask``; both ``train_label``
    and ``val_label`` are then mapped with ``(x - mmin) / dev``.

    Args:
        train_label: training labels.
        train_label_mask: mask with as many elements as ``train_label``;
            truthy entries mark the values used for the statistics.
        val_label: validation labels, scaled with the training statistics.
        val_label_mask: accepted but not used by this function.

    Returns:
        (train_label, val_label, mmin, dev). ``dev`` is 1 when all selected
        training values are identical, so the scaled output stays finite
        and ``x * dev + mmin`` still inverts the transform.

    Raises:
        ValueError: if the mask selects no training value.
    """
    flat_labels = np.asarray(train_label).ravel()
    # Force boolean selection; an integer 0/1 mask would otherwise be
    # interpreted as a list of indices.
    flat_mask = np.asarray(train_label_mask, dtype=bool).ravel()
    observed = flat_labels[flat_mask]
    if observed.size == 0:
        raise ValueError('train_label_mask selects no values; '
                         'cannot compute the label range')

    mmin = observed.min()
    dev = observed.max() - mmin
    if dev == 0:
        dev = dev.dtype.type(1)

    train_label = (train_label - mmin) / dev
    val_label = (val_label - mmin) / dev

    return train_label, val_label,mmin,dev


def stand_concs(train_label1,train_label2, train_label2_mask,
                test_label1,test_label2,test_label2_mask,
                unbias_train_label1,unbias_train_label2):
    """
    Min-max scale the concentration labels feature by feature.

    train_label1 / test_label1 shape is B_T_F
    train_label2 / test_label2 shape is B_T_N_F

    For every feature i in range(train_label2.shape[-1]) the minimum and
    the range are taken over both reference arrays (``unbias_train_label1``
    and ``unbias_train_label2``, each flattened to (-1, F)). Before that,
    entries of the second reference array that lie below its smallest
    non-negative value are raised to that value.

    The four label arrays are scaled IN PLACE with ``(x - mmin) / dev`` and
    also returned. The reference arrays are not modified: the clamp runs on
    a copy, because ``reshape`` normally returns a view and would otherwise
    write into the caller's data.

    ``train_label2_mask`` and ``test_label2_mask`` are accepted but not
    used by this function.

    Returns:
        (train_label1, train_label2, test_label1, test_label2,
         mmin_list, dev_list), the lists holding one entry per feature.
        A feature whose range is zero gets ``dev`` = 1 so that the result
        stays finite.

    Raises:
        ValueError: if ``unbias_train_label2`` has no non-negative entry.
    """

    temp1=unbias_train_label1.reshape(-1,unbias_train_label1.shape[-1])
    temp2=np.array(unbias_train_label2,copy=True).reshape(
        -1,unbias_train_label2.shape[-1])

    non_negative=temp2[temp2>=0]
    if non_negative.size==0:
        raise ValueError('unbias_train_label2 contains no non-negative '
                         'value to derive the clamp floor from')
    floor=non_negative.min()
    temp2[temp2<floor]=floor

    mmin_list=[]
    dev_list=[]
    for i in range(train_label2.shape[-1]):
        mmin=min(temp1[:,i].min(),temp2[:,i].min())
        mmax=max(temp1[:,i].max(),temp2[:,i].max())
        dev=mmax-mmin
        if dev==0:
            # Constant feature: avoid dividing by zero.
            dev=dev.dtype.type(1)

        train_label1[:,:,i]=(train_label1[:,:,i]-mmin)/dev
        train_label2[:,:,:,i]=(train_label2[:,:,:,i]-mmin)/dev
        test_label1[:,:,i]=(test_label1[:,:,i]-mmin)/dev
        test_label2[:,:,:,i]=(test_label2[:,:,:,i]-mmin)/dev
        mmin_list.append(mmin)
        dev_list.append(dev)
    return train_label1,train_label2,\
             test_label1,test_label2,\
             mmin_list,dev_list



def split_array(time_window,array,gap=1):
    '''
    Cut sliding windows of length ``time_window`` along axis 0.

    array is T_N (only axis 0 is sliced); the result stacks the windows on
    a new leading axis, i.e. B_T_N.

    Windows start at 0, gap, 2*gap, ... A start that would run past the end
    of the array is replaced by the last full window, so the tail of the
    series is always covered.

    Raises:
        ValueError: if ``gap`` is not positive or ``time_window`` exceeds
            the length of ``array`` (no full window exists).
    '''
    time_len=array.shape[0]
    if gap<=0:
        raise ValueError('gap must be a positive integer, got %r' % (gap,))
    if time_window>time_len:
        raise ValueError('time_window (%d) exceeds the series length (%d)'
                         % (time_window, time_len))

    last_start=time_len-time_window
    # Clamp an overshooting final start back to the last full window.
    starts=[min(i,last_start) for i in range(0,last_start+gap,gap)]
    return np.stack([array[s:s+time_window] for s in starts],axis=0)



#%%
def PrepareHeadDataset(time_window=30,res=30,train_ratio=0.6):
    """
    Build gridded train/validation tensors for the head dataset.

    Reads the pickles under 'cnn_data/', cuts sliding windows of
    ``time_window`` steps, splits every window along its time axis
    (the first int(time_window*train_ratio) steps are training, the rest
    validation), scales inputs and labels, scatters the nodes onto the grid
    built by create_edgelist and shuffles the samples.

    Args:
        time_window: window length passed to split_array.
        res: grid cell size passed to create_edgelist.
        train_ratio: fraction of each window used as the training part.

    Returns:
        (train_input, train_label, train_label_mask,
         val_input, val_label, val_label_mask,
         mmin, dev, heads_name_day)
        The first six are torch tensors (masks are boolean); ``mmin`` and
        ``dev`` come from stand_heads; ``heads_name_day`` is returned as
        loaded.

    Side effects:
        create_edgelist writes 'cnn_data/edge_list.pkl'. Shuffling uses the
        global ``random`` state and is not seeded here.
    """

    # Load data
    path='cnn_data/'
    with open(path+'input_data_head.pkl','rb') as f:
        input_data=pickle.load(f)

    # Loaded but not used further in this function.
    with open(path+'input_data_name_head.pkl','rb') as f:
        input_data_name=pickle.load(f)

    with open(path+'node_coord.pkl','rb') as f:
        node_coord=pickle.load(f)

    with open(path+'heads_day.pkl','rb') as f:
        heads_day=pickle.load(f)

    with open(path+'heads_name_day.pkl','rb') as f:
        heads_name_day=pickle.load(f)

    # Loaded but not used further in this function.
    with open(path+'heads_time_day.pkl','rb') as f:
        heads_time_day=pickle.load(f)


    edge_list,count=create_edgelist(node_coord,path,res)
    train_index=int(time_window*train_ratio)

    # Entries equal to 0 are treated as "no observation".
    heads_day_mask=heads_day!=0


    #B_T_N
    input_data=split_array(time_window,input_data)
    heads_day=split_array(time_window,heads_day)
    heads_day_mask=split_array(time_window,heads_day_mask)


    # Column 0 of the input is the time stamp, the remaining columns are nodes.
    train_input=input_data[:,:train_index,1:]
    train_time=input_data[:,:train_index,0]
    train_label=heads_day[:,:train_index,:]
    train_label_mask=heads_day_mask[:,:train_index,:]

    val_input=input_data[:,train_index:,1:]
    val_time=input_data[:,train_index:,0]
    val_label=heads_day[:,train_index:,:]
    val_label_mask=heads_day_mask[:,train_index:,:]


    # Scalers are fitted on the training part only and reused for validation.
    scaler2=MinMaxScaler().fit(train_time.flatten().reshape(-1,1))
    train_time=scaler2.transform(train_time.flatten().reshape(-1,1)).reshape(train_time.shape)
    val_time=scaler2.transform(val_time.flatten().reshape(-1,1)).reshape(val_time.shape)


    scaler1 = StandardScaler().fit(train_input.flatten().reshape(-1,1))
    train_input=scaler1.transform(train_input.flatten().reshape(-1,1)).reshape(train_input.shape)
    val_input=scaler1.transform(val_input.flatten().reshape(-1,1)).reshape(val_input.shape)


    train_label, val_label,mmin,dev=stand_heads(train_label, train_label_mask, val_label,val_label_mask)

    # B_T_N -> B_N_T, then pad the node axis up to the number of input nodes:
    # labels with -10, masks with False.
    train_label=train_label.transpose(0,2,1)
    train_label_mask=train_label_mask.transpose(0,2,1)
    train_label=np.concatenate((train_label,
                           -10*np.ones((train_label.shape[0],
                               (train_input.shape[2]-train_label.shape[1]),
                                        train_label.shape[2]))),axis=1)

    train_label_mask=np.concatenate((train_label_mask,
                                -10*np.ones((train_label_mask.shape[0],
                                    (train_input.shape[2]-train_label_mask.shape[1]),
                                             train_label_mask.shape[2]))>0),axis=1)

    val_label=val_label.transpose(0,2,1)
    val_label_mask=val_label_mask.transpose(0,2,1)
    val_label=np.concatenate((val_label,
                           -10*np.ones((val_label.shape[0],
                               (val_input.shape[2]-val_label.shape[1]),
                                        val_label.shape[2]))),axis=1)

    val_label_mask=np.concatenate((val_label_mask,
                                -10*np.ones((val_label_mask.shape[0],
                                    (val_input.shape[2]-val_label_mask.shape[1]),
                                             val_label_mask.shape[2]))>0),axis=1)


    # Add a trailing feature axis: B_N_T -> B_N_T_1
    train_label=train_label[:,:,:,np.newaxis]
    train_label_mask=train_label_mask[:,:,:,np.newaxis]
    val_label=val_label[:,:,:,np.newaxis]
    val_label_mask=val_label_mask[:,:,:,np.newaxis]



    # B_T_N + B_T -> B_N_T_F with F = (time, value)
    train_input=org_seri(train_time,train_input)
    val_input=org_seri(val_time,val_input)



    #B_N_T_F->B_T_F_N
    train_input=train_input.transpose(0,2,3,1)
    train_label=train_label.transpose(0,2,3,1)
    train_label_mask=train_label_mask.transpose(0,2,3,1)
    val_input=val_input.transpose(0,2,3,1)
    val_label=val_label.transpose(0,2,3,1)
    val_label_mask=val_label_mask.transpose(0,2,3,1)

    # g2cnn requires a `paddings` argument; the calls here used to omit it
    # and therefore always raised TypeError.
    # NOTE(review): the values below mirror PrepareConcDataset (time channel
    # 'same', value channel = standardised zero, labels/masks = -10, which
    # matches the -10 node padding above and becomes False after `>0`).
    # Confirm they are the intended fill values for the head model.
    input_paddings=['same',(0-scaler1.mean_)/np.sqrt(scaler1.var_)]

    #array,B_T_C_H_W
    train_input=g2cnn(train_input,edge_list,count,paddings=input_paddings)
    train_label=g2cnn(train_label,edge_list,count,paddings=-10)
    train_label_mask=g2cnn(train_label_mask,edge_list,count,paddings=-10)
    val_input=g2cnn(val_input,edge_list,count,paddings=input_paddings)
    val_label=g2cnn(val_label,edge_list,count,paddings=-10)
    val_label_mask=g2cnn(val_label_mask,edge_list,count,paddings=-10)



    #B_T_C_H_W
    train_input=torch.Tensor(train_input)
    train_label=torch.Tensor(train_label)
    train_label_mask=torch.Tensor(train_label_mask)>0
    val_input=torch.Tensor(val_input)
    val_label=torch.Tensor(val_label)
    val_label_mask=torch.Tensor(val_label_mask)>0



    # Shuffle samples; inputs, labels and masks share one permutation.
    index = [i for i in range(train_input.shape[0])]
    random.shuffle(index)
    train_input=train_input[index]
    train_label=train_label[index]
    train_label_mask=train_label_mask[index]

    index = [i for i in range(val_input.shape[0])]
    random.shuffle(index)
    val_input=val_input[index]
    val_label=val_label[index]
    val_label_mask=val_label_mask[index]


    # train_dataset = utils.TensorDataset(train_input, train_label)
    # val_dataset = utils.TensorDataset(val_input, val_label)

    # train_dataloader = utils.DataLoader(train_dataset, batch_size = BATCH_SIZE)
    # val_dataloader = utils.DataLoader(val_dataset, batch_size = BATCH_SIZE)

    return train_input,train_label,train_label_mask,\
            val_input,val_label,val_label_mask,\
            mmin,dev,heads_name_day


#%%
# Module-level values equal to the keyword defaults of PrepareConcDataset
# (defined right after them).
# NOTE(review): presumably left over from running the function body
# interactively as a `#%%` cell. No function visible in this file reads
# these globals, since each one takes same-named parameters — confirm
# nothing imports them before removing.
time_window=1
res=30
pred_day=1
buffer=10
gap=1
test_day=60
val_ratio=0.2
def PrepareConcDataset(time_window=1,res=30,pred_day=1,buffer=10,
                       gap=1,test_day=60,val_ratio=0.2,flag='train'):
    """
    Build gridded train/(validation)/test tensors for the concentration dataset.

    Reads the pickles under '../cnn_data/', holds out the last ``test_day``
    rows as the test period, cuts sliding windows, scales inputs and labels
    with statistics of the unsplit training period, scatters the nodes onto
    the grid built by create_edgelist and shuffles the training samples.

    Args:
        time_window: number of labelled steps per sample.
        res: grid cell size passed to create_edgelist.
        pred_day: not used in the body of this function.
        buffer: number of extra leading input steps per sample; the labels
            of those steps are dropped.
        gap: stride between consecutive windows (passed to split_array).
        test_day: number of trailing rows held out as the test period.
        val_ratio: fraction of the shuffled training samples split off as
            validation when ``flag == 'train'``.
        flag: 'train' adds the validation split to the result; any other
            value returns train and test only.

    Returns:
        flag == 'train': a 20-tuple
            (train_input, train_label1, train_label1_coe, train_label2,
             train_label2_mask, val_..., test_..., mmin_list, dev_list,
             input_data_name_conc[1:19], edge_list, scaler2)
        otherwise: the same tuple without the five val_* entries.

    Side effects:
        create_edgelist writes '../cnn_data/edge_list.pkl';
        ``random.seed(1000)`` resets the global ``random`` state.
    """

    path='../cnn_data/'

    with open(path+'input_data_conc.pkl','rb') as f:
        input_data_conc=pickle.load(f)

    with open(path+'input_data_name_conc.pkl','rb') as f:
        input_data_name_conc=pickle.load(f)

    with open(path+'input_H_conc.pkl','rb') as f:
        input_H_conc=pickle.load(f)

    # Column 0 is shifted and then dropped on the next line, so the shift
    # does not affect the array that is used from here on.
    input_H_conc[:,0]=input_H_conc[:,0]-input_H_conc[:,0].min()
    input_H_conc=np.delete(input_H_conc,0,axis=1)

    with open(path+'conc_day_6.pkl','rb') as f:
        conc_day=pickle.load(f)

    with open(path+'conc_day_coe.pkl','rb') as f:
        conc_day_coe=pickle.load(f)

    with open(path+'conc_week.pkl','rb') as f:
        conc_week=pickle.load(f)

    with open(path+'node_coord.pkl','rb') as f:
        node_coord=pickle.load(f)


    edge_list,count=create_edgelist(node_coord,path,res)

    # Negative entries of conc_week are treated as "no observation".
    conc_week_mask=conc_week>=0

    test_index=test_day

    # Everything except the last test_index rows is the training period.
    train_input_flux=input_data_conc[:-test_index]
    train_input_H=input_H_conc[:-test_index]
    train_label1=conc_day[:-test_index]
    train_label1_coe=conc_day_coe[:-test_index]
    train_label2=conc_week[:-test_index]
    train_label2_mask=conc_week_mask[:-test_index]

    # Test inputs reach `buffer` rows further back than the test labels so
    # that every test window has its leading buffer steps.
    test_input_flux=input_data_conc[-test_index-buffer:]
    test_input_H=input_H_conc[-test_index-buffer:]
    test_label1=conc_day[-test_index:]
    test_label1_coe=conc_day_coe[-test_index:]
    test_label2=conc_week[-test_index:]
    test_label2_mask=conc_week_mask[-test_index:]

    # Unsplit training-period data, used below only to fit the scalers.
    # NOTE(review): unbias_train_label1/2 are not copies; they alias the
    # training slices of conc_day / conc_week.
    unbias_train_input_flux=train_input_flux[:,1:]
    unbias_train_input_H=train_input_H.copy()
    unbias_train_time=train_input_flux[:,0:1]
    unbias_train_label1=train_label1
    unbias_train_label2=train_label2



    # Sliding windows. Training arrays and test inputs use
    # time_window+buffer steps; test labels use time_window steps only.
    train_input_flux=split_array(time_window+buffer,train_input_flux,gap)
    train_input_H=split_array(time_window+buffer,train_input_H,gap)
    train_label1=split_array(time_window+buffer,train_label1,gap)
    train_label1_coe=split_array(time_window+buffer,train_label1_coe,gap)
    train_label2=split_array(time_window+buffer,train_label2,gap)
    train_label2_mask=split_array(time_window+buffer,train_label2_mask,gap)
    test_input_flux=split_array(time_window+buffer,test_input_flux,gap)
    test_input_H=split_array(time_window+buffer,test_input_H,gap)
    test_label1=split_array(time_window,test_label1,gap)
    test_label1_coe=split_array(time_window,test_label1_coe,gap)
    test_label2=split_array(time_window,test_label2,gap)
    test_label2_mask=split_array(time_window,test_label2_mask,gap)



    # Column 0 of the flux input is the time stamp; training labels drop
    # the first `buffer` steps of every window.
    train_time=train_input_flux[:,:,0]
    train_input_flux=train_input_flux[:,:,1:]
    train_label1=train_label1[:,buffer:,:]
    train_label1_coe=train_label1_coe[:,buffer:,:]
    train_label2=train_label2[:,buffer:,:]
    train_label2_mask=train_label2_mask[:,buffer:,:]

    test_time=test_input_flux[:,:,0]
    test_input_flux=test_input_flux[:,:,1:]



    # All scalers are fitted on the unsplit training period and applied to
    # both the training and the test windows.
    scaler2=MinMaxScaler().fit(unbias_train_time.reshape(-1,1))
    train_time=scaler2.transform(train_time.flatten().reshape(-1,1)).reshape(train_time.shape)
    test_time=scaler2.transform(test_time.flatten().reshape(-1,1)).reshape(test_time.shape)


    scaler1 = StandardScaler().fit(unbias_train_input_flux.reshape(-1,1))
    train_input_flux=scaler1.transform(train_input_flux.flatten().reshape(-1,1)).reshape(train_input_flux.shape)
    test_input_flux=scaler1.transform(test_input_flux.flatten().reshape(-1,1)).reshape(test_input_flux.shape)
    scaler3 = StandardScaler().fit(unbias_train_input_H.reshape(-1,1))
    train_input_H=scaler3.transform(train_input_H.flatten().reshape(-1,1)).reshape(train_input_H.shape)
    test_input_H=scaler3.transform(test_input_H.flatten().reshape(-1,1)).reshape(test_input_H.shape)



    #B_T_N_F
    # Per-feature min-max scaling of the labels; mmin_list / dev_list hold
    # the offset and range of every feature.
    train_label1,train_label2, \
    test_label1,test_label2,\
    mmin_list,dev_list=stand_concs(train_label1,train_label2, train_label2_mask,
                                   test_label1,test_label2,test_label2_mask,
                                   unbias_train_label1,unbias_train_label2)

    # label1 gets a length-1 axis at position 1; the coefficient arrays get
    # a trailing length-1 axis.
    train_label1=train_label1[:,np.newaxis,:,:]
    train_label1_coe=train_label1_coe[:,:,:,np.newaxis]
    test_label1=test_label1[:,np.newaxis,:,:]
    test_label1_coe=test_label1_coe[:,:,:,np.newaxis]


    # Pad axis 2 (nodes) up to the number of input nodes, then swap axes 1
    # and 2. Fill values: 0 for coefficients, -10 for labels, False for
    # masks (`-10*ones > 0` is all False).
    train_label1_coe=np.concatenate((train_label1_coe,
                                     0*np.ones((train_label1_coe.shape[0],
                                                train_label1_coe.shape[1],
                                                (train_input_flux.shape[2]-train_label1_coe.shape[2]),
                                                train_label1_coe.shape[3]))
                                     ),axis=2).transpose(0,2,1,3)

    train_label2=np.concatenate((train_label2,
                           -10*np.ones((train_label2.shape[0],
                                        train_label2.shape[1],
                                        (train_input_flux.shape[2]-train_label2.shape[2]),
                                        train_label2.shape[3]))
                           ),axis=2).transpose(0,2,1,3)

    train_label2_mask=np.concatenate((train_label2_mask,
                                -10*np.ones((train_label2_mask.shape[0],
                                             train_label2_mask.shape[1],
                                             (train_input_flux.shape[2]-train_label2_mask.shape[2]),
                                             train_label2_mask.shape[3]))>0
                                ),axis=2).transpose(0,2,1,3)


    test_label2=np.concatenate((test_label2,
                           -10*np.ones((test_label2.shape[0],
                                        test_label2.shape[1],
                                        (test_input_flux.shape[2]-test_label2.shape[2]),
                                        test_label2.shape[3]))
                           ),axis=2).transpose(0,2,1,3)

    test_label2_mask=np.concatenate((test_label2_mask,
                                -10*np.ones((test_label2_mask.shape[0],
                                             test_label2_mask.shape[1],
                                             (test_input_flux.shape[2]-test_label2_mask.shape[2]),
                                             test_label2_mask.shape[3]))>0
                                ),axis=2).transpose(0,2,1,3)

    test_label1_coe=np.concatenate((test_label1_coe,
                             0*np.ones((test_label1_coe.shape[0],
                                        test_label1_coe.shape[1],
                                        (test_input_flux.shape[2]-test_label1_coe.shape[2]),
                                        test_label1_coe.shape[3]))
                           ),axis=2).transpose(0,2,1,3)



    # org_seri pairs time with each node's value (B_N_T_2). For H only the
    # value channel is kept ([..., -1:]), since time already comes with flux.
    train_input_flux=org_seri(train_time,train_input_flux)
    test_input_flux=org_seri(test_time,test_input_flux)
    train_input_H=org_seri(train_time,train_input_H)[:,:,:,-1:]
    test_input_H=org_seri(test_time,test_input_H)[:,:,:,-1:]


    # Input features per node: (time, flux, H). Labels keep only their
    # first feature.
    train_input=np.concatenate([train_input_flux,train_input_H],axis=-1)
    # train_input=train_input_flux
    train_label1=train_label1[...,0:1]
    train_label1_coe=train_label1_coe[...,0:1]
    train_label2=train_label2[...,0:1]
    train_label2_mask=train_label2_mask[...,0:1]

    test_input=np.concatenate([test_input_flux,test_input_H],axis=-1)
    # test_input=test_input_flux
    test_label1=test_label1[...,0:1]
    test_label1_coe=test_label1_coe[...,0:1]
    test_label2=test_label2[...,0:1]
    test_label2_mask=test_label2_mask[...,0:1]

    # Keep the scaling constants of the first feature only, matching the
    # label slicing above.
    mmin_list=np.array(mmin_list[0:1])
    dev_list=np.array(dev_list[0:1])



    #B_N_T_F->B_T_F_N
    train_input=train_input.transpose(0,2,3,1)
    train_label1=train_label1.transpose(0,2,3,1)
    train_label1_coe=train_label1_coe.transpose(0,2,3,1)
    train_label2=train_label2.transpose(0,2,3,1)
    train_label2_mask=train_label2_mask.transpose(0,2,3,1)
    test_input=test_input.transpose(0,2,3,1)
    test_label1=test_label1.transpose(0,2,3,1)
    test_label1_coe=test_label1_coe.transpose(0,2,3,1)
    test_label2=test_label2.transpose(0,2,3,1)
    test_label2_mask=test_label2_mask.transpose(0,2,3,1)



    # Scatter the node axis onto the grid. Empty cells of the input get:
    # channel 0 (time) the value of node 0 ('same'), channels 1 and 2 the
    # standardised value of a raw 0 for flux and H. Label1 is not gridded;
    # it only receives one more trailing axis.
    # train_input=g2cnn(train_input,edge_list,count,paddings=(0-scaler1.mean_)/np.sqrt(scaler1.var_))
    # train_input=g2cnn(train_input,edge_list,count,paddings=['same',(0-scaler1.mean_)/np.sqrt(scaler1.var_)])
    train_input=g2cnn(train_input,edge_list,count,paddings=['same',(0-scaler1.mean_)/np.sqrt(scaler1.var_),
                                                            (0-scaler3.mean_)/np.sqrt(scaler3.var_)])
    train_label1=train_label1[:,:,:,:,np.newaxis]
    train_label1_coe=g2cnn(train_label1_coe,edge_list,count,paddings=0)
    train_label2=g2cnn(train_label2,edge_list,count,paddings=-10)
    train_label2_mask=g2cnn(train_label2_mask,edge_list,count,paddings=-10)
    # test_input=g2cnn(test_input,edge_list,count,paddings=(0-scaler1.mean_)/np.sqrt(scaler1.var_))
    # test_input=g2cnn(test_input,edge_list,count,paddings=['same',(0-scaler1.mean_)/np.sqrt(scaler1.var_)])
    test_input=g2cnn(test_input,edge_list,count,paddings=['same',(0-scaler1.mean_)/np.sqrt(scaler1.var_),
                                                          (0-scaler3.mean_)/np.sqrt(scaler3.var_)])
    test_label1=test_label1[:,:,:,:,np.newaxis]
    test_label1_coe=g2cnn(test_label1_coe,edge_list,count,paddings=0)
    test_label2=g2cnn(test_label2,edge_list,count,paddings=-10)
    test_label2_mask=g2cnn(test_label2_mask,edge_list,count,paddings=-10)



    #B_T_C_H_W
    # Masks were turned into floats by g2cnn (-10 padding, 0/1 values);
    # `> 0` converts them back to booleans.
    train_input=torch.Tensor(train_input)
    train_label1=torch.Tensor(train_label1)
    train_label1_coe=torch.Tensor(train_label1_coe)
    train_label2=torch.Tensor(train_label2)
    train_label2_mask=torch.Tensor(train_label2_mask)>0

    test_input=torch.Tensor(test_input)
    test_label1=torch.Tensor(test_label1)
    test_label1_coe=torch.Tensor(test_label1_coe)
    test_label2=torch.Tensor(test_label2)
    test_label2_mask=torch.Tensor(test_label2_mask)>0

    # Fixed seed: the training shuffle (and therefore the validation split
    # below) is reproducible. This reseeds the global `random` module.
    random.seed(1000)
    index = [i for i in range(train_input.shape[0])]
    random.shuffle(index)
    train_input=train_input[index]
    train_label1=train_label1[index]
    train_label1_coe=train_label1_coe[index]
    train_label2=train_label2[index]
    train_label2_mask=train_label2_mask[index]

    if flag=='train':

        # The first val_ratio share of the shuffled samples becomes the
        # validation set; the rest stays as training data.
        val_index=int(train_input.shape[0]*val_ratio)
        val_input=train_input[:val_index]
        val_label1=train_label1[:val_index]
        val_label1_coe=train_label1_coe[:val_index]
        val_label2=train_label2[:val_index]
        val_label2_mask=train_label2_mask[:val_index]

        train_input=train_input[val_index:]
        train_label1=train_label1[val_index:]
        train_label1_coe=train_label1_coe[val_index:]
        train_label2=train_label2[val_index:]
        train_label2_mask=train_label2_mask[val_index:]


        # train_dataset = utils.TensorDataset(train_input, train_label)
        # val_dataset = utils.TensorDataset(val_input, val_label)


        # train_dataloader = utils.DataLoader(train_dataset, batch_size = BATCH_SIZE)
        # val_dataloader = utils.DataLoader(val_dataset, batch_size = BATCH_SIZE)

        return train_input,train_label1,train_label1_coe,train_label2,train_label2_mask,\
                val_input,val_label1,val_label1_coe,val_label2,val_label2_mask,\
                test_input,test_label1,test_label1_coe,test_label2,test_label2_mask,\
                mmin_list,dev_list,input_data_name_conc[1:19],edge_list,scaler2

    else:
        # No validation split: all shuffled training samples are returned.
        return train_input,train_label1,train_label1_coe,train_label2,train_label2_mask,\
                test_input,test_label1,test_label1_coe,test_label2,test_label2_mask,\
                mmin_list,dev_list,input_data_name_conc[1:19],edge_list,scaler2
